{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 235/235 [00:19<00:00, 12.09it/s]\n",
      "  0%|                                                                                          | 0/235 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, train loss 0.6313, train acc 0.747\n",
      "test loss 0.5513, test acc 0.792\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 235/235 [00:22<00:00, 10.50it/s]\n",
      "  0%|                                                                                          | 0/235 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, train loss 0.4444, train acc 0.811\n",
      "test loss 0.4733, test acc 0.822\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 235/235 [00:16<00:00, 14.02it/s]\n",
      "  1%|█                                                                                 | 3/235 [00:00<00:11, 19.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, train loss 0.5256, train acc 0.825\n",
      "test loss 0.4603, test acc 0.829\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 235/235 [00:16<00:00, 14.62it/s]\n",
      "  0%|▎                                                                                 | 1/235 [00:00<00:23,  9.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, train loss 0.4397, train acc 0.831\n",
      "test loss 0.4781, test acc 0.822\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 235/235 [00:22<00:00, 10.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, train loss 0.4250, train acc 0.836\n",
      "test loss 0.4499, test acc 0.834\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import tensor\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import torch.utils.data as Data\n",
    "from torch.nn import init\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "mnist_train = torchvision.datasets.FashionMNIST(\n",
    "    root='~/Datasets/FashionMNIST',\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transforms.ToTensor())  # 将所有数据转换为Tensor\n",
    "mnist_test = torchvision.datasets.FashionMNIST(\n",
    "    root='~/Datasets/FashionMNIST',\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=transforms.ToTensor()\n",
    ")\n",
    "# 2、通过Dataloader读取小批量数据样本\n",
    "batch_size = 256\n",
    "train_iter = torch.utils.data.DataLoader(\n",
    "    mnist_train,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=0\n",
    ")\n",
    "test_iter = torch.utils.data.DataLoader(\n",
    "    mnist_train,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    "    num_workers=0\n",
    ")\n",
    "\n",
    "num_inputs = 784\n",
    "num_outputs = 10  # 共10类\n",
    "class softmaxnet(torch.nn.Module):\n",
    "    def __init__(self, n_features, n_labels):\n",
    "        super(softmaxnet, self).__init__()\n",
    "        self.linear = torch.nn.Linear(n_features, n_labels)\n",
    "    def forward(self, x):\n",
    "        x_ = x.view((-1, num_inputs))\n",
    "        y_ = self.linear(x_)\n",
    "        return y_\n",
    "    \n",
    "net = softmaxnet(num_inputs, num_outputs)\n",
    "lr = 0.1\n",
    "loss = torch.nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=lr)\n",
    "\n",
    "def get_test_info(data_iter, net):\n",
    "    right_count, all_count = 0.0, 0\n",
    "    for x, y in data_iter:\n",
    "        y_ = net(x)\n",
    "        l = loss(y_, y)\n",
    "        right_count += (y_.argmax(dim=1)==y).sum().item()\n",
    "        all_count += y.shape[0]\n",
    "    return right_count/all_count, l.item()\n",
    "num_epoch = 5\n",
    "\n",
    "\n",
    "for epoch in range(num_epoch):\n",
    "    train_r_num, train_all_num = 0.0, 0\n",
    "    for X, y in tqdm(train_iter):\n",
    "        y_ = net(X)\n",
    "        l = loss(y_, y)\n",
    "        l.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "        train_r_num += (y_.argmax(dim=1) == y).sum().item()\n",
    "        train_all_num += y.shape[0]\n",
    "    test_acc, test_ave_loss = get_test_info(test_iter, net)\n",
    "    print('epoch %d, train loss %.4f, train acc %.3f' % (epoch+1, l.item(), train_r_num/train_all_num))\n",
    "    print('test loss %.4f, test acc %.3f' % (test_ave_loss, test_acc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
